{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import time\n",
    "import torch\n",
    "import torch.nn as nn \n",
    "import torch.nn.functional as F \n",
    "import d2lzh as d2l \n",
    "\n",
    "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
    "\n",
    "def batch_norm(is_training,x,gamma,beta,moving_mean,moving_var,eps,momentum):\n",
    "    if not is_training:\n",
    "        #　在预测模式下，直接使用传入的moving mean和moving var\n",
    "        x_hat = (x-moving_mean) / torch.sqrt(moving_var + eps)\n",
    "    else:\n",
    "        assert len(x.shape) in (2,4)\n",
    "        if len(x.shape) == 2: # 全连接层，计算特征维度的均值和方差\n",
    "            mean = x.mean(dim=0)\n",
    "            var = ((x -mean) ** 2).mean(dim=0)\n",
    "        else:\n",
    "            # 使用二维卷积，计算通道维度上的均值和方差\n",
    "            mean = x.mean(dim=0,keepdim=True).mean(dim=2,keepdim=True).mean(dim=3,keepdim=True)\n",
    "            var = ((x-mean) ** 2).mean(dim=0,keepdim=True).mean(dim=2,keepdim=True).mean(dim=3,keepdim=True)\n",
    "        x_hat = (x-mean) / torch.sqrt(var+eps)\n",
    "\n",
    "        moving_mean = momentum * moving_mean + (1.0 - momentum) * mean\n",
    "        moving_var = momentum * moving_var + (1.0  - momentum) * var \n",
    "\n",
    "    y = gamma * x_hat + beta\n",
    "    return y,moving_mean,moving_var\n",
    "\n",
    "class BatchNorm(nn.Module):\n",
    "    def __init__(self,num_features,num_dims):\n",
    "        super(BatchNorm,self).__init__()\n",
    "\n",
    "        if num_dims == 2:\n",
    "            shape = (1,num_features)\n",
    "        else:\n",
    "            shape = (1,num_features,1,1)\n",
    "\n",
    "        self.gamma = nn.Parameter(torch.ones(shape))\n",
    "        self.beta  = nn.Parameter(torch.zeros(shape))\n",
    "\n",
    "        self.moving_mean = torch.zeros(shape)\n",
    "        self.moving_var = torch.zeros(shape)\n",
    "\n",
    "    def forward(self,x):\n",
    "        if self.moving_var.device != x.device:\n",
    "            self.moving_var = self.moving_var.to(x.device)\n",
    "            self.moving_mean = self.moving_mean.to(x.device)\n",
    "\n",
    "        y,self.moving_mean,self.moving_var = batch_norm(self.training,x,self.gamma,self.beta,self.moving_mean,self.moving_var,eps=1e-5,momentum=0.9)\n",
    "\n",
    "        return y"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "output_type": "stream",
     "name": "stdout",
     "text": "training on  cuda\nepoch 1, loss 0.9969, train acc 0.789,test acc 0.821,time 3.8\nepoch 2, loss 0.2267, train acc 0.865,test acc 0.837,time 2.6\nepoch 3, loss 0.1224, train acc 0.878,test acc 0.842,time 2.6\nepoch 4, loss 0.0833, train acc 0.885,test acc 0.876,time 2.7\nepoch 5, loss 0.0625, train acc 0.891,test acc 0.833,time 2.6\n"
    }
   ],
   "source": [
    "net = nn.Sequential(\n",
    "    nn.Conv2d(1,6,5),\n",
    "    BatchNorm(6,num_dims=4),\n",
    "    nn.Sigmoid(),\n",
    "    nn.MaxPool2d(2,2),\n",
    "    nn.Conv2d(6,16,5),\n",
    "    BatchNorm(16,num_dims=4),\n",
    "    nn.Sigmoid(),\n",
    "    nn.MaxPool2d(2,2),\n",
    "    d2l.FlattenLayer(),\n",
    "    nn.Linear(16 * 4 * 4,120),\n",
    "    BatchNorm(120,num_dims=2),\n",
    "    nn.Sigmoid(),\n",
    "    nn.Linear(120,84),\n",
    "    BatchNorm(84,num_dims=2),\n",
    "    nn.Sigmoid(),\n",
    "    nn.Linear(84,10)\n",
    ")\n",
    "\n",
    "batch_size = 256\n",
    "train_iter,test_iter = d2l.load_data_fashion_mnist(batch_size)\n",
    "\n",
    "lr,num_epochs = 0.001,5\n",
    "optimizer = torch.optim.Adam(net.parameters(),lr=lr)\n",
    "d2l.train_ch5(net,train_iter,test_iter,batch_size,optimizer,device,num_epochs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "output_type": "stream",
     "name": "stdout",
     "text": "training on  cuda\nepoch 1, loss 1.3784, train acc 0.764,test acc 0.805,time 2.0\nepoch 2, loss 0.3021, train acc 0.858,test acc 0.818,time 1.9\nepoch 3, loss 0.1377, train acc 0.877,test acc 0.686,time 1.9\nepoch 4, loss 0.0878, train acc 0.886,test acc 0.850,time 1.9\nepoch 5, loss 0.0647, train acc 0.891,test acc 0.786,time 1.9\n"
    }
   ],
   "source": [
    "net = nn.Sequential(\n",
    "    nn.Conv2d(1,6,5),\n",
    "    nn.BatchNorm2d(6),\n",
    "    nn.Sigmoid(),\n",
    "    nn.MaxPool2d(2,2),\n",
    "    nn.Conv2d(6,16,5),\n",
    "    nn.BatchNorm2d(16),\n",
    "    nn.Sigmoid(),\n",
    "    nn.MaxPool2d(2,2),\n",
    "    d2l.FlattenLayer(),\n",
    "    nn.Linear(16 * 4 * 4,120),\n",
    "    nn.BatchNorm1d(120),\n",
    "    nn.Sigmoid(),\n",
    "    nn.Linear(120,84),\n",
    "    nn.BatchNorm1d(84),\n",
    "    nn.Sigmoid(),\n",
    "    nn.Linear(84,10)\n",
    ")\n",
    "\n",
    "batch_size = 256\n",
    "train_iter,test_iter = d2l.load_data_fashion_mnist(batch_size)\n",
    "\n",
    "lr,num_epochs = 0.001,5\n",
    "optimizer = torch.optim.Adam(net.parameters(),lr=lr)\n",
    "d2l.train_ch5(net,train_iter,test_iter,batch_size,optimizer,device,num_epochs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.7-final"
  },
  "orig_nbformat": 2,
  "kernelspec": {
   "name": "python37764bitpytorchnotebookconda6e7a8693df0d4d92aca09d521275d23a",
   "display_name": "Python 3.7.7 64-bit ('pytorch_notebook': conda)"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}